package org.hepeng.workx.spring.security.web.filter;

import org.apache.commons.codec.binary.Base64;
import org.apache.commons.lang3.StringUtils;
import org.hepeng.workx.web.util.HttpRequestUtils;
import org.joor.Reflect;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.serializer.JdkSerializationRedisSerializer;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestWrapper;
import org.springframework.util.Assert;

import javax.servlet.http.HttpServletRequest;
import java.util.Objects;

/**
 * A {@link SkipOverSpringSecurityFilterChainFilter} that treats any request carrying the
 * {@value #SECURITY_CONTEXT_HEADER_NAME} header as an upstream (API gateway) request. Such requests
 * are reported via {@link #isSkipOver(HttpServletRequest)} as ones that should skip over the Spring
 * Security filter chain. Their {@link SecurityContext} is taken from that header instead.
 *
 * <p>On construction the filter reflectively replaces the global {@link SecurityContextHolder}
 * strategy with a wrapper. While the current request carries the header,
 * {@link SecurityContextHolder#getContext()} returns the upstream context. Otherwise the call is
 * delegated to the original strategy. NOTE(review): a later call to
 * {@code SecurityContextHolder.setStrategyName(..)} re-initializes the holder and silently discards
 * this wrapper.
 *
 * <p><b>NOTE(review), security:</b> nothing in this class verifies that the header was set by the
 * gateway. Its mere presence skips the security chain. Its value is Base64-decoded and deserialized
 * with Java native serialization ({@link JdkSerializationRedisSerializer}). Unless every network
 * path to this application strips or overwrites the header, a client can forge an arbitrary
 * identity and feed attacker-controlled bytes to Java deserialization. Consider signing the payload
 * and/or restricting deserializable classes.
 *
 * @author he peng
 */
public class UpstreamRequestSkipOverSpringSecurityChainFilter extends SkipOverSpringSecurityFilterChainFilter {

    /** Request header holding the Base64-encoded, Java-serialized upstream {@link SecurityContext}. */
    public static final String SECURITY_CONTEXT_HEADER_NAME = "x-forwarded-security-context";

    /** Logger; the category is the nested strategy class, as before. */
    private static final Logger LOG = LoggerFactory.getLogger(UpstreamSecurityContextHolderStrategy.class);

    /** Deserializer for the decoded header bytes (Java native serialization). */
    private final RedisSerializer<Object> redisSerializer = new JdkSerializationRedisSerializer();

    /**
     * Creates the filter and installs the upstream-aware strategy into {@link SecurityContextHolder}.
     */
    public UpstreamRequestSkipOverSpringSecurityChainFilter() {
        registerSecurityContextHolderStrategy();
    }

    /**
     * Returns {@code true} when the request carries a non-blank upstream security context header.
     */
    @Override
    protected boolean isSkipOver(HttpServletRequest request) {
        return isUpstreamRequest(request);
    }

    /**
     * Returns {@code true} if {@code request} is non-null and carries a non-blank
     * {@value #SECURITY_CONTEXT_HEADER_NAME} header.
     */
    private boolean isUpstreamRequest(HttpServletRequest request) {
        // Null-safe: the global strategy may be consulted on threads without a current request.
        return request != null && StringUtils.isNotBlank(request.getHeader(SECURITY_CONTEXT_HEADER_NAME));
    }

    /**
     * Decodes the upstream {@link SecurityContext} from the request header.
     *
     * @param request the current request (may be {@code null})
     * @return the deserialized context, or {@code null} in any of these cases: there is no request;
     *         the header is blank; the payload deserializes to {@code null} (e.g. the Base64 value
     *         decodes to no bytes); the payload is not a {@link SecurityContext}
     */
    protected SecurityContext getUpstreamSecurityContext(HttpServletRequest request) {
        if (request == null) {
            return null;
        }
        String encoded = request.getHeader(SECURITY_CONTEXT_HEADER_NAME);
        if (StringUtils.isBlank(encoded)) {
            return null;
        }

        // SECURITY: Java native deserialization of a request header; see class-level note.
        Object payload = this.redisSerializer.deserialize(Base64.decodeBase64(encoded));
        if (payload instanceof SecurityContext) {
            return (SecurityContext) payload;
        }
        if (payload != null) {
            LOG.warn("upstream security context header did not contain a SecurityContext but {}" ,
                    payload.getClass().getName());
        }
        return null;
    }

    /**
     * Wraps the request so servlet security API calls (e.g. {@code isUserInRole}) use the
     * {@code "ROLE_"} prefix.
     */
    @Override
    protected HttpServletRequest wrapRequest(HttpServletRequest request) {
        return new SecurityContextHolderAwareRequestWrapper(request , "ROLE_");
    }

    /**
     * Replaces the private static {@code strategy} field of {@link SecurityContextHolder} with an
     * {@link UpstreamSecurityContextHolderStrategy} delegating to the current strategy.
     */
    private void registerSecurityContextHolderStrategy() {
        Reflect holder = Reflect.on(SecurityContextHolder.class);
        SecurityContextHolderStrategy current = holder.get("strategy");
        // Do not stack wrappers when more than one instance of this filter is created.
        if (current instanceof UpstreamSecurityContextHolderStrategy) {
            return;
        }
        holder.set("strategy" , new UpstreamSecurityContextHolderStrategy(current));
    }

    /**
     * Strategy that serves the upstream {@link SecurityContext} for upstream requests and delegates
     * every other operation to the previously installed strategy.
     */
    private class UpstreamSecurityContextHolderStrategy implements SecurityContextHolderStrategy {

        private final SecurityContextHolderStrategy delegate;

        /**
         * @param strategy the strategy to delegate to. If {@code null}, a new Spring
         *                 {@code ThreadLocalSecurityContextHolderStrategy} is created reflectively;
         *                 that class is not public, hence the reflection.
         */
        public UpstreamSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) {
            if (Objects.nonNull(strategy)) {
                this.delegate = strategy;
            } else {
                Reflect reflect = Reflect.on("org.springframework.security.core.context.ThreadLocalSecurityContextHolderStrategy");
                this.delegate = reflect.create().get();
                Assert.notNull(this.delegate , "delegate must not be null");
            }
        }

        @Override
        public void clearContext() {
            this.delegate.clearContext();
        }

        /**
         * Returns the upstream context for upstream requests and the delegate's context otherwise.
         * It never returns {@code null}: if an upstream header cannot be decoded, an empty context
         * is returned instead.
         */
        @Override
        public SecurityContext getContext() {
            // NOTE(review): presumably returns null when the current thread has no bound request.
            HttpServletRequest request = HttpRequestUtils.getHttpServletRequest();
            if (!isUpstreamRequest(request)) {
                return this.delegate.getContext();
            }

            SecurityContext upstream = getUpstreamSecurityContext(request);
            if (upstream == null) {
                // Honour the "never null" contract. Use an empty context rather than whatever
                // the delegate holds, since the security chain was skipped for this request.
                LOG.warn("upstream security context header could not be decoded, using an empty SecurityContext");
                return createEmptyContext();
            }
            // Debug level: getContext() is called many times per request and the context holds principal details.
            LOG.debug("current http request comes from the api gateway , SecurityContext {} " , upstream);
            return upstream;
        }

        @Override
        public void setContext(SecurityContext context) {
            this.delegate.setContext(context);
        }

        @Override
        public SecurityContext createEmptyContext() {
            return this.delegate.createEmptyContext();
        }
    }
}
